from Network.network_utils import pytorch_model, assign_distribution
import numpy as np
import torch
from tianshou.data import Batch

def _as_at_least_1d(val):
    # promotes a 0-d torch tensor / numpy array to shape (1,) so it can be concatenated;
    # any other value (including already >=1-d arrays) is returned unchanged
    if len(val.shape) == 0:
        if isinstance(val, torch.Tensor): return val.unsqueeze(0)
        if isinstance(val, np.ndarray): return np.expand_dims(val, axis=0)
    return val


def aggregate_result(combined, new_result, step_num, combine_type="average", cast_numpy=False):
    '''
    aggregates result batches of the same kind together, writing into @param combined
    @param combined Batch (or dict-like) of values aggregated so far; modified in place and returned
    @param new_result Batch (or dict-like) of values to fold in; keys missing from combined are copied over as-is
    @param step_num zero-based index of new_result; "average" computes combined * (n-1)/n + new / n
        with n = step_num + 1 (assumes combined is the mean of the previous step_num results)
    @param combine_type "average", "cat0" (concatenate on dim 0) or "catn1" (concatenate on dim -1);
        the key "omit_flags" is always concatenated on dim 1
    @param cast_numpy if True, torch tensors are unwrapped to numpy (pytorch_model.unwrap) before combining
    nested Batch values are aggregated recursively with the same step_num / combine_type / cast_numpy
    @raises NotImplementedError for an unknown combine_type (only once a key present in both is reached)
    '''
    for k in new_result.keys():
        if k not in combined:
            combined[k] = new_result[k]
            continue
        if isinstance(combined[k], Batch):
            # nested batches: aggregate recursively with the same settings
            combined[k] = aggregate_result(combined[k], new_result[k], step_num,
                                           combine_type=combine_type, cast_numpy=cast_numpy)
            continue
        val = new_result[k]
        if cast_numpy and isinstance(val, torch.Tensor): val = pytorch_model.unwrap(val)
        if cast_numpy and isinstance(combined[k], torch.Tensor): combined[k] = pytorch_model.unwrap(combined[k])
        if combine_type == "average":
            n = step_num + 1
            # assumes array-like values whose shapes match
            combined[k] = combined[k] * (n - 1) / n + val / n
        elif combine_type in ["cat0", "catn1"]:
            # scalars cannot be concatenated, so lift them to shape (1,) first
            val = _as_at_least_1d(val)
            combined[k] = _as_at_least_1d(combined[k])
            if k == "omit_flags": cat_dim = 1  # TODO: hacky fix, omit flags are on the second dim
            elif combine_type == "cat0": cat_dim = 0
            else: cat_dim = -1
            if isinstance(combined[k], torch.Tensor): combined[k] = torch.cat((combined[k], val), dim=cat_dim)
            elif isinstance(combined[k], np.ndarray): combined[k] = np.concatenate((combined[k], val), axis=cat_dim)
        else:
            raise NotImplementedError("combine type " + str(combine_type) +" not implemented, types are: average, cat0, catn1" )
    return combined

def filter_batch_names(combined, names):
    '''
    builds a new Batch holding only the entries of @param combined whose key is in @param names
    nested Batches are searched recursively; a nested Batch is kept (in filtered form) if any entry inside it matched
    @param combined Batch to filter; not modified
    @param names collection of keys to keep
    @return (filtered Batch, True if anything was kept at any depth)
    '''
    filtered_combined = Batch()
    kept_any = False
    for k in combined.keys():
        if k in names:
            filtered_combined[k] = combined[k]
            kept_any = True
        elif isinstance(combined[k], Batch):
            # use separate names so a nested miss does not reset what was already kept at this level
            sub_filtered, sub_kept = filter_batch_names(combined[k], names)
            if sub_kept:
                filtered_combined[k] = sub_filtered
                kept_any = True
    return filtered_combined, kept_any



def get_done_flags(batch, iscuda):
    '''
    returns (1 - batch.done) wrapped as a tensor (pytorch_model.wrap, on cuda if @param iscuda),
    squeezed and then given a trailing singleton dim
    NOTE(review): squeeze() drops every size-1 dim, including the batch dim when the batch has
        a single element -- confirm callers expect that
    '''
    not_done = 1 - batch.done
    flags = pytorch_model.wrap(not_done, cuda=iscuda)
    flags = flags.squeeze()
    return flags.unsqueeze(-1)

def get_target(model, batch, predict_dynamics, name, iscuda=False):
    '''
    selects the prediction target for @param name
    @param predict_dynamics if True reads batch.target_diff, otherwise batch.next_target
    @param name passed as names= to model.extractor.get_named_target to pick the relevant entries
    @return the selected target wrapped with pytorch_model.wrap (on cuda if @param iscuda)
    '''
    if predict_dynamics:
        raw_target = batch.target_diff
    else:
        raw_target = batch.next_target
    named_target = model.extractor.get_named_target(raw_target, names=name)
    return pytorch_model.wrap(named_target, cuda=iscuda)

def compute_dist(args, model, result, batch, name=""):
    '''
    builds the predicted distribution from result.params and scores the target under it
    result.params: its is_cuda attribute selects the device, and it is unpacked positionally into
        model.dists.forward (presumably a stacked tensor of distribution parameters -- confirm)
    uses args.inter.predict_dynamics (via get_target) to choose target_diff vs next_target
    adds result.target, result.dist, result.done_flags (1 - batch.done, see get_done_flags), result.log_probs
    @return result (the same object, mutated)
    '''
    iscuda = result.params.is_cuda
    # reuse the shared helper instead of duplicating the wrap/squeeze/unsqueeze logic
    done_flags = get_done_flags(batch, iscuda)
    target = get_target(model, batch, args.inter.predict_dynamics, name, iscuda=iscuda)

    # compute the distributional difference
    dist = model.dists.forward(*result.params)
    log_probs = dist.log_prob(target)

    # adds target, dist, done_flags, log_probs
    result.target, result.dist, result.done_flags, result.log_probs = target, dist, done_flags, log_probs
    return result

def reshape_object(args, val):
    '''
    splits the flat feature dim into per-object chunks:
    [batch, size] -> [batch, size / args.factor.single_obj_dim, args.factor.single_obj_dim]
    '''
    obj_dim = args.factor.single_obj_dim
    batch_size = val.shape[0]
    return val.reshape(batch_size, -1, obj_dim)


def compute_likelihood(args, result, batch, model, name=""):
    '''
    computes per-object log likelihoods of the target, zeroing out entries where batch.done is set
    expects tensors, not necessarily cuda
    uses: result.params (through compute_dist), args.inter.predict_dynamics, batch.done, args.factor.single_obj_dim
    adds everything compute_dist adds (target, dist, done_flags, log_probs) plus
        result.loss_log_probs: the done-masked log probs reshaped to batch, num_obj, obj_dim
    Unused because we call _target_dists in inference_module
    '''
    result = compute_dist(args, model, result, batch, name=name)
    # result.done_flags is (1 - batch.done) with a trailing singleton dim, presumably broadcasting over features
    loss_log_probs = result.log_probs * result.done_flags # batch, num_obj*obj_dim
    # reshape the masked values; reshaping result.log_probs here would discard the done mask
    result.loss_log_probs = reshape_object(args, loss_log_probs) # batch, num_obj, obj_dim
    return result

def compute_expectile(log_probs, ord=2, tau=0.5, threshold=0):
    '''
    asymmetric loss of the mean log prob against a threshold
    @param log_probs log probabilities (higher = more accurate); averaged over the last dim first
    @param ord forwarded to asymmetric_loss; 2 is intended as the expectile, 1 as the quantile
    @param tau asymmetry level of the expectile / quantile
    @param threshold reference log prob subtracted before the loss; hard to set a priori,
        should be estimated from data
    @return asymmetric_loss(mean log prob - threshold, tau=tau, ord=ord)
    '''
    mean_log_prob = log_probs.mean(dim=-1)
    residual = mean_log_prob - threshold
    return asymmetric_loss(residual, tau=tau, ord=ord)

def asymmetric_loss(u: torch.Tensor, tau: float, ord: float = 2) -> torch.Tensor:
    '''
    asymmetric regression loss: mean over elements of |tau - 1[u < 0]| * |u| ** ord
    ord=1 is the quantile (pinball) loss, ord=2 the expectile loss
    computed elementwise in torch so gradients flow through u and cuda tensors work
    (a global numpy norm would detach u and collapse it to a single scalar)
    '''
    weight = torch.abs(tau - (u < 0).float())
    return torch.mean(weight * torch.abs(u) ** ord)

def compute_likelihood_adaptive_lasso(no_use_adaptive, base_value, adaptive_lasso, result, batch, baseline_likelihood, bias, flatten_factor, pointwise=False):
    '''
    scales @param adaptive_lasso by how close the model log-likelihood is to a reference value
    returns @param base_value unchanged when @param no_use_adaptive is truthy
    otherwise: ll = result.log_probs summed over the last dim (and averaged over the batch unless @param pointwise),
        floored at -10; ratio = exp(-|baseline_likelihood - bias - ll| / flatten_factor)
    @param batch unused
    @return adaptive_lasso * ratio (one value per sample when pointwise, a single value otherwise)
    '''
    if no_use_adaptive: return base_value
    log_likelihood = result.log_probs.sum(dim=-1)
    if not pointwise: log_likelihood = log_likelihood.mean()
    # elementwise floor at -10; np.max over a (per-sample array, -10) tuple is ragged and fails
    clipped = np.maximum(pytorch_model.unwrap(log_likelihood), -10)
    difference = -np.abs(baseline_likelihood - bias - clipped)
    ratio = np.exp(difference / flatten_factor)
    return adaptive_lasso * ratio

def compute_mean_adaptive_lasso(use_adaptive, base_value, adaptive_lasso, result, batch, flatten_factor, pointwise=False):
    '''
    scales @param adaptive_lasso by how close the predicted mean is to the target
    @param use_adaptive despite its name, a truthy value returns @param base_value unchanged
        (compute_adaptive_rate passes its no_use_adaptive flag in this position)
    otherwise: err = L1 distance between result.mean and batch.target over the last dim,
        averaged over the batch unless @param pointwise; ratio = exp(-err / flatten_factor)
    @return adaptive_lasso * ratio
    '''
    if use_adaptive: return base_value
    mean, target = result.mean, batch.target
    # torch.linalg.norm takes ord (p is the torch.norm keyword and is rejected here)
    mean_difference = pytorch_model.unwrap(torch.linalg.norm(mean - target, ord=1, dim=-1))
    # pointwise keeps one value per sample, matching compute_likelihood_adaptive_lasso
    if not pointwise: mean_difference = np.mean(mean_difference)
    ratio = np.exp(-mean_difference / flatten_factor)
    return adaptive_lasso * ratio

def compute_mean_var_adaptive_lasso(use_adaptive, base_value, adaptive_lasso, result, batch, flatten_factor, pointwise=False):
    '''
    scales @param adaptive_lasso by mean error plus predicted variance
    @param use_adaptive despite its name, a truthy value returns @param base_value unchanged
        (compute_adaptive_rate passes its no_use_adaptive flag in this position)
    otherwise: err = L1 distance between result.mean and batch.target over the last dim
        + L1 norm of result.var over the last dim, averaged over the batch unless @param pointwise;
        ratio = exp(-err / flatten_factor)
    @return adaptive_lasso * ratio
    '''
    # truthy check, like the sibling rules (a boolean flag is never < 0)
    if use_adaptive: return base_value
    mean, var, target = result.mean, result.var, batch.target
    mean_difference = pytorch_model.unwrap(torch.linalg.norm(mean - target, ord=1, dim=-1))
    confidence = pytorch_model.unwrap(torch.linalg.norm(var, ord=1, dim=-1))
    combined = mean_difference + confidence
    # pointwise keeps one value per sample, matching compute_likelihood_adaptive_lasso
    if not pointwise: combined = np.mean(combined)
    ratio = np.exp(-combined / flatten_factor)
    return adaptive_lasso * ratio

def compute_adaptive_rate(ada_type, no_use_adaptive, base_value, adaptive_lasso, result, batch, baseline_likelihood, bias, flatten_factor, pointwise=False):
    '''
    dispatches to the adaptive lasso rule named by @param ada_type: "likelihood", "mean" or "meanvar"
    baseline_likelihood and bias are only used by the "likelihood" rule
    @return the rule's lasso lambda (base_value when no_use_adaptive is truthy)
    @raises NotImplementedError for any other ada_type
    '''
    if ada_type == "likelihood": lasso_lambda = compute_likelihood_adaptive_lasso(no_use_adaptive, base_value, adaptive_lasso, result, batch, baseline_likelihood, bias, flatten_factor, pointwise=pointwise)
    elif ada_type == "mean": lasso_lambda = compute_mean_adaptive_lasso(no_use_adaptive, base_value, adaptive_lasso, result, batch, flatten_factor, pointwise=pointwise)
    elif ada_type == "meanvar": lasso_lambda = compute_mean_var_adaptive_lasso(no_use_adaptive, base_value, adaptive_lasso, result, batch, flatten_factor, pointwise=pointwise)
    else: raise NotImplementedError("adaptive type " + str(ada_type) + " not implemented, types are: likelihood, mean, meanvar")
    return lasso_lambda